Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for conv_transpose2d operation #1540

Merged
merged 1 commit into from
Jan 16, 2025

Conversation

jserbedzijaTT
Copy link
Contributor

@jserbedzijaTT jserbedzijaTT commented Dec 9, 2024

closes #1084

@nsmithtt
Copy link
Contributor

nsmithtt commented Dec 9, 2024

Adding @LPanosTT

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

lib/Dialect/TTNN/Transforms/TTNNLayout.cpp Outdated Show resolved Hide resolved
@LPanosTT
Copy link
Contributor

LPanosTT commented Dec 9, 2024

Hey thanks for adding this. I have something to say about this op though. It seems as though some frontends reverse the order of the data in the kernel window for this op, and some do not. I.e PyTorch does (and thus TTNN does) and JAX does not. You will see that ttir.convolution has a window_reversal boolean attr as well. In order to model the cases in all frontends we need this attribute for conv_transpose2d in ttnn. Or for us to add ttir.reverse so we can consteval the window reversal away.

There is an issue to add window_reversal to ttnn: tenstorrent/tt-metal#15342

@LPanosTT
Copy link
Contributor

LPanosTT commented Dec 9, 2024

Also if you could add a pattern to lower ttir.convolution to ttir.conv_transpose2d that would be great. Check out the stablehlo spec for convolution, which ttir.convolution is meant to mimic to see how you can tell if a given convolution is a transposed convolution or not.

@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch from b000588 to 7b36217 Compare December 20, 2024 12:58
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 2 times, most recently from 683fb3b to 4ddde58 Compare December 23, 2024 11:06
@jserbedzijaTT
Copy link
Contributor Author

jserbedzijaTT commented Dec 24, 2024

Also if you could add a pattern to lower ttir.convolution to ttir.conv_transpose2d that would be great. Check out the stablehlo spec for convolution, which ttir.convolution is meant to mimic to see how you can tell if a given convolution is a transposed convolution or not.

I will merge this pr as is but I have opened an issue to track the things you mentioned: #1662

@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 2 times, most recently from 62b9199 to 2837812 Compare December 24, 2024 10:32
@mtopalovicTT
Copy link
Contributor

Copy link
Contributor

@sdjordjevicTT sdjordjevicTT left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great change Joco, thanks, couple of comments inline.

lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp Outdated Show resolved Hide resolved
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp Outdated Show resolved Hide resolved
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp Outdated Show resolved Hide resolved
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp Outdated Show resolved Hide resolved
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp Show resolved Hide resolved
lib/Dialect/TTIR/IR/TTIROps.cpp Outdated Show resolved Hide resolved
lib/Dialect/TTNN/IR/TTNNOps.cpp Outdated Show resolved Hide resolved
lib/Dialect/TTNN/Transforms/TTNNLayout.cpp Outdated Show resolved Hide resolved
lib/Target/TTNN/TTNNToFlatbuffer.cpp Outdated Show resolved Hide resolved
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 8 times, most recently from 065d821 to e119bc3 Compare December 30, 2024 13:55
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch from e119bc3 to 055d423 Compare January 3, 2025 13:54
Copy link
Contributor

@nsmithtt nsmithtt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work!

@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 2 times, most recently from 9806f22 to 0c200c0 Compare January 13, 2025 14:07
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch 7 times, most recently from 1512fa3 to 63a697e Compare January 16, 2025 11:29
@jserbedzijaTT jserbedzijaTT force-pushed the jserbedzija/add_conv_transpose2d_operation branch from 63a697e to 6b2ac8e Compare January 16, 2025 11:29
@jserbedzijaTT jserbedzijaTT enabled auto-merge (squash) January 16, 2025 12:27
@jserbedzijaTT jserbedzijaTT merged commit 6a76e4a into main Jan 16, 2025
20 checks passed
@jserbedzijaTT jserbedzijaTT deleted the jserbedzija/add_conv_transpose2d_operation branch January 16, 2025 12:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Ops] Support for conv_transpose2d op (ttnn.conv_transpose2d)
6 participants